/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
 */

package org.postgresql.readwritesplitting;

import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.SQLWarning;
import java.sql.Statement;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;

/**
 * Read write splitting statement.
 *
 * @since 2023-11-20
 */
public class ReadWriteSplittingPgStatement implements Statement {
    // Every physical statement created by this wrapper. A new one is created per execute* call and all of
    // them are kept here until close(). NOTE(review): settings or batches applied to an earlier physical
    // statement are not carried over to the one created by the next execute* call.
    private final List<Statement> statements = new LinkedList<>();

    private final ForceExecuteTemplate<Statement> forceExecuteTemplate = new ForceExecuteTemplate<>();

    private final ReadWriteSplittingPgConnection readWriteSplittingPgConnection;

    private final int resultSetType;

    private final int resultSetConcurrency;

    private final int resultSetHoldability;

    // Physical statement that the delegating methods currently operate on; null until first needed.
    private Statement currentStatement;

    // Result set returned by the latest executeQuery; null when the current result must be fetched from
    // the current physical statement instead.
    private ResultSet currentResultSet;

    private boolean isClosed;

    /**
     * Constructor.
     *
     * @param readWriteSplittingPgConnection read write splitting connection
     * @param resultSetType result set type
     * @param resultSetConcurrency result set concurrency
     * @param resultSetHoldability result set holdability
     */
    public ReadWriteSplittingPgStatement(ReadWriteSplittingPgConnection readWriteSplittingPgConnection,
                                         int resultSetType, int resultSetConcurrency, int resultSetHoldability) {
        this.readWriteSplittingPgConnection = readWriteSplittingPgConnection;
        this.resultSetType = resultSetType;
        this.resultSetConcurrency = resultSetConcurrency;
        this.resultSetHoldability = resultSetHoldability;
    }

    @Override
    public ResultSet executeQuery(String sql) throws SQLException {
        Statement pgStatement = createPgStatement(sql);
        ResultSet result = pgStatement.executeQuery(sql);
        currentResultSet = result;
        return result;
    }

    /**
     * Create a physical statement on the connection the SQL is routed to and make it the current statement.
     *
     * @param sql SQL used to pick the routed connection
     * @return newly created physical statement
     * @throws SQLException SQL exception
     */
    private Statement createPgStatement(String sql) throws SQLException {
        Connection connection = SqlRouteEngine.getRoutedConnection(readWriteSplittingPgConnection, sql);
        Statement statement = register(
                connection.createStatement(resultSetType, resultSetConcurrency, resultSetHoldability));
        // A result cached from an earlier executeQuery does not belong to the new statement.
        currentResultSet = null;
        return statement;
    }

    /**
     * Track a physical statement so that it is closed with this wrapper, and make it the current statement.
     *
     * @param statement physical statement
     * @return the same statement
     */
    private Statement register(Statement statement) {
        statements.add(statement);
        currentStatement = statement;
        return statement;
    }

    /**
     * Get current statement, creating one on the current connection if none exists yet.
     *
     * @return current statement
     * @throws SQLException SQL exception
     */
    public Statement getCurrentStatement() throws SQLException {
        if (currentStatement != null) {
            return currentStatement;
        }
        Connection connection = readWriteSplittingPgConnection.getConnectionManager().getCurrentConnection();
        // Use the configured result set options, exactly as the routed statements do.
        return register(connection.createStatement(resultSetType, resultSetConcurrency, resultSetHoldability));
    }

    @Override
    public boolean execute(String sql) throws SQLException {
        Statement pgStatement = createPgStatement(sql);
        return pgStatement.execute(sql);
    }

    @Override
    public boolean execute(String sql, int autoGeneratedKeys) throws SQLException {
        Statement pgStatement = createPgStatement(sql);
        return pgStatement.execute(sql, autoGeneratedKeys);
    }

    @Override
    public boolean execute(String sql, int[] columnIndexes) throws SQLException {
        Statement pgStatement = createPgStatement(sql);
        return pgStatement.execute(sql, columnIndexes);
    }

    @Override
    public boolean execute(String sql, String[] columnNames) throws SQLException {
        Statement pgStatement = createPgStatement(sql);
        return pgStatement.execute(sql, columnNames);
    }

    @Override
    public int executeUpdate(String sql) throws SQLException {
        Statement pgStatement = createPgStatement(sql);
        return pgStatement.executeUpdate(sql);
    }

    @Override
    public int executeUpdate(String sql, int autoGeneratedKeys) throws SQLException {
        Statement pgStatement = createPgStatement(sql);
        return pgStatement.executeUpdate(sql, autoGeneratedKeys);
    }

    @Override
    public int executeUpdate(String sql, int[] columnIndexes) throws SQLException {
        Statement pgStatement = createPgStatement(sql);
        return pgStatement.executeUpdate(sql, columnIndexes);
    }

    @Override
    public int executeUpdate(String sql, String[] columnNames) throws SQLException {
        Statement pgStatement = createPgStatement(sql);
        return pgStatement.executeUpdate(sql, columnNames);
    }

    /**
     * Get the physical statements created by this statement so far.
     *
     * @return live internal collection of physical statements (not a copy)
     */
    public Collection<Statement> getRoutedStatements() {
        return statements;
    }

    @Override
    public void close() throws SQLException {
        isClosed = true;
        try {
            forceExecuteTemplate.execute(getRoutedStatements(), Statement::close);
        } finally {
            // Forget the physical statements even when closing one of them failed.
            getRoutedStatements().clear();
        }
    }

    @Override
    public boolean isClosed() throws SQLException {
        return isClosed;
    }

    @Override
    public ResultSet getResultSet() throws SQLException {
        if (currentResultSet != null) {
            return currentResultSet;
        }
        // Results produced by execute(...) or getMoreResults(...) live on the physical statement. Do not
        // create a statement here: without any execution there is no result to return.
        return currentStatement == null ? null : currentStatement.getResultSet();
    }

    @Override
    public int[] executeBatch() throws SQLException {
        return getCurrentStatement().executeBatch();
    }

    @Override
    public int getMaxFieldSize() throws SQLException {
        return getCurrentStatement().getMaxFieldSize();
    }

    @Override
    public void setMaxFieldSize(int max) throws SQLException {
        getCurrentStatement().setMaxFieldSize(max);
    }

    @Override
    public int getMaxRows() throws SQLException {
        return getCurrentStatement().getMaxRows();
    }

    @Override
    public void setMaxRows(int max) throws SQLException {
        getCurrentStatement().setMaxRows(max);
    }

    @Override
    public void setEscapeProcessing(boolean isEnabled) throws SQLException {
        getCurrentStatement().setEscapeProcessing(isEnabled);
    }

    @Override
    public int getQueryTimeout() throws SQLException {
        return getCurrentStatement().getQueryTimeout();
    }

    @Override
    public void setQueryTimeout(int seconds) throws SQLException {
        getCurrentStatement().setQueryTimeout(seconds);
    }

    @Override
    public void cancel() throws SQLException {
        forceExecuteTemplate.execute(getRoutedStatements(), Statement::cancel);
    }

    @Override
    public SQLWarning getWarnings() throws SQLException {
        return getCurrentStatement().getWarnings();
    }

    @Override
    public void clearWarnings() throws SQLException {
        getCurrentStatement().clearWarnings();
    }

    @Override
    public void setCursorName(String name) throws SQLException {
        getCurrentStatement().setCursorName(name);
    }

    @Override
    public int getUpdateCount() throws SQLException {
        return getCurrentStatement().getUpdateCount();
    }

    @Override
    public boolean getMoreResults() throws SQLException {
        boolean hasResultSet = getCurrentStatement().getMoreResults();
        // The cached result is no longer the current one; let getResultSet() ask the physical statement.
        currentResultSet = null;
        return hasResultSet;
    }

    @Override
    public void setFetchDirection(int direction) throws SQLException {
        getCurrentStatement().setFetchDirection(direction);
    }

    @Override
    public int getFetchDirection() throws SQLException {
        return getCurrentStatement().getFetchDirection();
    }

    @Override
    public void setFetchSize(int rows) throws SQLException {
        getCurrentStatement().setFetchSize(rows);
    }

    @Override
    public int getFetchSize() throws SQLException {
        return getCurrentStatement().getFetchSize();
    }

    @Override
    public int getResultSetConcurrency() throws SQLException {
        return getCurrentStatement().getResultSetConcurrency();
    }

    @Override
    public int getResultSetType() throws SQLException {
        return getCurrentStatement().getResultSetType();
    }

    @Override
    public void addBatch(String sql) throws SQLException {
        getCurrentStatement().addBatch(sql);
    }

    @Override
    public void clearBatch() throws SQLException {
        getCurrentStatement().clearBatch();
    }

    @Override
    public Connection getConnection() throws SQLException {
        return readWriteSplittingPgConnection;
    }

    @Override
    public boolean getMoreResults(int current) throws SQLException {
        // Forward the caller's flag; the no-arg overload would ignore it.
        boolean hasResultSet = getCurrentStatement().getMoreResults(current);
        // The cached result is no longer the current one; let getResultSet() ask the physical statement.
        currentResultSet = null;
        return hasResultSet;
    }

    @Override
    public ResultSet getGeneratedKeys() throws SQLException {
        return getCurrentStatement().getGeneratedKeys();
    }

    @Override
    public int getResultSetHoldability() throws SQLException {
        return getCurrentStatement().getResultSetHoldability();
    }

    @Override
    public void setPoolable(boolean isPoolable) throws SQLException {
        getCurrentStatement().setPoolable(isPoolable);
    }

    @Override
    public boolean isPoolable() throws SQLException {
        return getCurrentStatement().isPoolable();
    }

    @Override
    public void closeOnCompletion() throws SQLException {
        getCurrentStatement().closeOnCompletion();
    }

    @Override
    public boolean isCloseOnCompletion() throws SQLException {
        return getCurrentStatement().isCloseOnCompletion();
    }

    @Override
    public <T> T unwrap(Class<T> iface) throws SQLException {
        return getCurrentStatement().unwrap(iface);
    }

    @Override
    public boolean isWrapperFor(Class<?> iface) throws SQLException {
        return getCurrentStatement().isWrapperFor(iface);
    }
}
